{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebeca6b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "# hack to allow argparse to work in notebook\n",
    "sys.argv = [\"main.py\"]\n",
    "\n",
    "import os\n",
    "import time\n",
    "import random\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "import timm\n",
    "import albumentations as A\n",
    "from albumentations.pytorch import ToTensorV2\n",
    "\n",
    "from sklearn.model_selection import StratifiedKFold\n",
    "from sklearn.metrics import roc_auc_score, confusion_matrix\n",
    "\n",
    "import cv2\n",
    "import argparse\n",
    "\n",
    "parser = argparse.ArgumentParser()\n",
    "parser.add_argument('--debug', action='store_true', help='Run in debug mode')\n",
    "args = parser.parse_args()\n",
    "DEBUG = args.debug\n",
    "\n",
    "SEED = 2024\n",
    "np.random.seed(SEED)\n",
    "random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "torch.cuda.manual_seed_all(SEED)\n",
    "\n",
    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "TRAIN_DIR = './workspace_input/train/'\n",
    "TEST_DIR = './workspace_input/test/'\n",
    "TRAIN_CSV = './workspace_input/train.csv'\n",
    "SAMPLE_SUB_PATH = './workspace_input/sample_submission.csv'\n",
    "MODEL_DIR = 'models/'\n",
    "os.makedirs(MODEL_DIR, exist_ok=True)\n",
    "\n",
    "class CactusDataset(Dataset):\n",
    "    def __init__(self, image_ids, labels=None, id2path=None, transforms=None):\n",
    "        self.image_ids = image_ids\n",
    "        self.labels = labels\n",
    "        self.id2path = id2path\n",
    "        self.transforms = transforms\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.image_ids)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        img_id = self.image_ids[idx]\n",
    "        img_path = self.id2path[img_id]\n",
    "        image = cv2.imread(img_path)\n",
    "        if image is None:\n",
    "            raise RuntimeError(f\"Cannot read image at {img_path}\")\n",
    "        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
    "        if self.transforms:\n",
    "            augmented = self.transforms(image=image)\n",
    "            image = augmented[\"image\"]\n",
    "        if self.labels is not None:\n",
    "            label = self.labels[idx]\n",
    "            return image, label, img_id\n",
    "        else:\n",
    "            return image, img_id\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9086e8dc",
   "metadata": {},
   "source": [
    "## Data Loading and Preprocessing\n",
    "This section loads the train and test data, performs EDA, and prepares the dataset.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05509a31",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_class_weight(y):\n",
    "    counts = np.bincount(y)\n",
    "    if len(counts) < 2:\n",
    "        counts = np.pad(counts, (0, 2-len(counts)), constant_values=0)\n",
    "    n_pos, n_neg = counts[1], counts[0]\n",
    "    total = n_pos + n_neg\n",
    "    minority, majority = min(n_pos, n_neg), max(n_pos, n_neg)\n",
    "    ratio = majority / (minority + 1e-10)\n",
    "    need_weights = ratio > 2\n",
    "    weights = None\n",
    "    if need_weights:\n",
    "        inv_freq = [1 / (n_neg + 1e-10), 1 / (n_pos + 1e-10)]\n",
    "        s = sum(inv_freq)\n",
    "        weights = [w / s * 2 for w in inv_freq]\n",
    "    return weights, n_pos, n_neg, ratio, need_weights\n",
    "\n",
    "def print_eda(train_df):\n",
    "    print(\"=== Start of EDA part ===\")\n",
    "    print(\"Shape of train.csv:\", train_df.shape)\n",
    "    print(\"First 5 rows:\\n\", train_df.head())\n",
    "    print(\"Column data types:\\n\", train_df.dtypes)\n",
    "    print(\"Missing values per column:\\n\", train_df.isnull().sum())\n",
    "    print(\"Unique values per column:\")\n",
    "    for col in train_df.columns:\n",
    "        print(f\" - {col}: {train_df[col].nunique()}\")\n",
    "    label_counts = train_df['has_cactus'].value_counts()\n",
    "    print(\"Label distribution (has_cactus):\")\n",
    "    print(label_counts)\n",
    "    pos, neg = label_counts.get(1, 0), label_counts.get(0, 0)\n",
    "    total = pos + neg\n",
    "    if total > 0:\n",
    "        print(f\"  Positive:Negative ratio: {pos}:{neg} ({pos/total:.3f}:{neg/total:.3f})\")\n",
    "        print(f\"  Percentage positive: {pos/total*100:.2f}%\")\n",
    "    else:\n",
    "        print(\"  No data found.\")\n",
    "    print(\"Image filename examples:\", train_df['id'].unique()[:5])\n",
    "    print(\"=== End of EDA part ===\")\n",
    "\n",
    "print(\"Section: Data Loading and Preprocessing\")\n",
    "try:\n",
    "    train_df = pd.read_csv(TRAIN_CSV)\n",
    "except Exception as e:\n",
    "    print(f\"Failed to load train.csv: {e}\")\n",
    "    sys.exit(1)\n",
    "print_eda(train_df)\n",
    "\n",
    "train_id2path = {img_id: os.path.join(TRAIN_DIR, img_id) for img_id in train_df['id']}\n",
    "try:\n",
    "    sample_sub = pd.read_csv(SAMPLE_SUB_PATH)\n",
    "except Exception as e:\n",
    "    print(f\"Failed to load sample_submission.csv: {e}\")\n",
    "    sys.exit(1)\n",
    "test_img_ids = list(sample_sub['id'])\n",
    "test_id2path = {img_id: os.path.join(TEST_DIR, img_id) for img_id in test_img_ids}\n",
    "print(f\"Loaded {len(train_id2path)} train images, {len(test_id2path)} test images.\")\n",
    "\n",
    "y_train = train_df['has_cactus'].values\n",
    "class_weights, n_pos, n_neg, imbalance_ratio, need_weights = compute_class_weight(y_train)\n",
    "print(f\"Class stats: Pos={n_pos}, Neg={n_neg}, Imbalance Ratio(majority/minority)={imbalance_ratio:.3f}\")\n",
    "print(f\"Use class weights: {need_weights}, Class weights: {class_weights if class_weights is not None else '[1.0,1.0]'}\")\n",
    "if class_weights is not None:\n",
    "    np.save(os.path.join(MODEL_DIR, \"class_weights.npy\"), class_weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b201cd3f",
   "metadata": {},
   "source": [
    "## Feature Engineering\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7d4697e",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Section: Feature Engineering\")\n",
    "train_df = train_df.copy()\n",
    "cv_fold = 5\n",
    "skf = StratifiedKFold(n_splits=cv_fold, shuffle=True, random_state=SEED)\n",
    "folds = np.zeros(len(train_df), dtype=np.int32)\n",
    "for idx, (_, val_idx) in enumerate(skf.split(train_df['id'], train_df['has_cactus'])):\n",
    "    folds[val_idx] = idx\n",
    "train_df['fold'] = folds\n",
    "print(f\"Assigned stratified {cv_fold}-fold indices. Fold sample counts:\")\n",
    "for f in range(cv_fold):\n",
    "    dist = train_df.loc[train_df['fold'] == f, 'has_cactus'].value_counts().to_dict()\n",
    "    print(f\"  Fold {f}: n={len(train_df[train_df['fold'] == f])} class dist={dist}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23e606da",
   "metadata": {},
   "source": [
    "## Model Training and Evaluation\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "853b0c24",
   "metadata": {},
   "outputs": [],
   "source": [
    "def inference_and_submission(train_df, train_id2path, test_img_ids, test_id2path, dropout_rate, class_weights, need_weights,\n",
    "                            BATCH_SIZE, N_WORKERS, cv_fold):\n",
    "    oof_true, oof_pred, fold_scores, fold_val_ids = [], [], [], []\n",
    "    for fold in range(cv_fold):\n",
    "        df_val = train_df[train_df['fold'] == fold].reset_index(drop=True)\n",
    "        val_img_ids = df_val['id'].tolist()\n",
    "        val_labels = df_val['has_cactus'].values\n",
    "        val_ds = CactusDataset(val_img_ids, val_labels, id2path=train_id2path, transforms=get_transforms(\"val\"))\n",
    "        val_loader = get_dataloader(val_ds, BATCH_SIZE, shuffle=False, num_workers=N_WORKERS)\n",
    "        fold_model_path = os.path.join(MODEL_DIR, f\"efficientnet_b3_fold{fold}.pt\")\n",
    "        model = get_efficientnet_b3(dropout_rate=dropout_rate)\n",
    "        model.load_state_dict(torch.load(fold_model_path, map_location='cpu'))\n",
    "        model.to(DEVICE)\n",
    "        model.eval()\n",
    "        fold_class_weights = class_weights if need_weights else None\n",
    "        if fold_class_weights is not None:\n",
    "            fold_class_weights = torch.tensor(fold_class_weights).float().to(DEVICE)\n",
    "        loss_fn = nn.BCEWithLogitsLoss(reduction='none')\n",
    "        _, val_true, val_pred = eval_model(model, loss_fn, val_loader, DEVICE, fold_class_weights)\n",
    "        val_auc = roc_auc_score(val_true, val_pred)\n",
    "        oof_true.append(val_true)\n",
    "        oof_pred.append(val_pred)\n",
    "        fold_val_ids.append(val_img_ids)\n",
    "        fold_scores.append(val_auc)\n",
    "        print(f\"Reloaded fold {fold}, OOF Validation AUC={val_auc:.5f}\")\n",
    "\n",
    "    all_oof_true = np.concatenate(oof_true)\n",
    "    all_oof_pred = np.concatenate(oof_pred)\n",
    "    oof_auc = roc_auc_score(all_oof_true, all_oof_pred)\n",
    "    oof_cm = confusion_info(all_oof_true, all_oof_pred)\n",
    "    print(f\"OOF ROC-AUC (from loaded models): {oof_auc:.5f}\")\n",
    "    print(f\"OOF Confusion Matrix:\\n{oof_cm}\")\n",
    "\n",
    "    test_ds = CactusDataset(\n",
    "        test_img_ids, labels=None,\n",
    "        id2path=test_id2path,\n",
    "        transforms=get_transforms(\"val\")\n",
    "    )\n",
    "    test_loader = get_dataloader(test_ds, BATCH_SIZE, shuffle=False, num_workers=N_WORKERS)\n",
    "    test_pred_list = []\n",
    "    for fold in range(cv_fold):\n",
    "        fold_model_path = os.path.join(MODEL_DIR, f\"efficientnet_b3_fold{fold}.pt\")\n",
    "        model = get_efficientnet_b3(dropout_rate=dropout_rate)\n",
    "        model.load_state_dict(torch.load(fold_model_path, map_location='cpu'))\n",
    "        model.to(DEVICE)\n",
    "        model.eval()\n",
    "        preds = []\n",
    "        with torch.no_grad():\n",
    "            for batch in test_loader:\n",
    "                images, img_ids = batch\n",
    "                images = images.to(DEVICE)\n",
    "                logits = model(images)\n",
    "                probs = torch.sigmoid(logits).cpu().numpy().reshape(-1)\n",
    "                preds.append(probs)\n",
    "        fold_test_pred = np.concatenate(preds)\n",
    "        test_pred_list.append(fold_test_pred)\n",
    "        print(f\"Loaded fold {fold} for test prediction.\")\n",
    "    test_probs = np.mean(test_pred_list, axis=0)\n",
    "\n",
    "    submission = pd.read_csv(SAMPLE_SUB_PATH)\n",
    "    submission['has_cactus'] = test_probs\n",
    "    submission.to_csv('submission.csv', index=False)\n",
    "    print(f\"Saved submission.csv in required format with {len(submission)} rows.\")\n",
    "\n",
    "    scores_df = pd.DataFrame({\n",
    "        'Model': [f\"efficientnet_b3_fold{f}\" for f in range(cv_fold)] + ['ensemble'],\n",
    "        'ROC-AUC': list(fold_scores) + [oof_auc]\n",
    "    })\n",
    "    scores_df.set_index('Model', inplace=True)\n",
    "    scores_df.to_csv(\"scores.csv\")\n",
    "    print(f\"Saved cross-validation scores to scores.csv\")\n",
    "\n",
    "def confusion_info(y_true, y_pred, threshold=0.5):\n",
    "    preds = (y_pred > threshold).astype(int)\n",
    "    cm = confusion_matrix(y_true, preds)\n",
    "    return cm\n",
    "\n",
    "@torch.no_grad()\n",
    "def eval_model(model, loss_fn, dataloader, device, class_weights):\n",
    "    model.eval()\n",
    "    y_true, y_pred = [], []\n",
    "    total_loss = 0.0\n",
    "    total_samples = 0\n",
    "    for batch in dataloader:\n",
    "        images, labels, _ = batch\n",
    "        images = images.to(device)\n",
    "        labels = labels.float().unsqueeze(1).to(device)\n",
    "        logits = model(images)\n",
    "        probs = torch.sigmoid(logits)\n",
    "        y_true.append(labels.cpu().numpy())\n",
    "        y_pred.append(probs.cpu().numpy())\n",
    "        if class_weights is not None:\n",
    "            weight = labels * class_weights[1] + (1 - labels) * class_weights[0]\n",
    "            loss = loss_fn(logits, labels)\n",
    "            loss = (loss * weight).mean()\n",
    "        else:\n",
    "            loss = loss_fn(logits, labels)\n",
    "        total_loss += loss.item() * labels.size(0)\n",
    "        total_samples += labels.size(0)\n",
    "    y_true = np.vstack(y_true).reshape(-1)\n",
    "    y_pred = np.vstack(y_pred).reshape(-1)\n",
    "    avg_loss = total_loss / total_samples\n",
    "    return avg_loss, y_true, y_pred\n",
    "\n",
    "def train_one_epoch(model, loss_fn, optimizer, scheduler, dataloader, device, class_weights):\n",
    "    model.train()\n",
    "    total_loss = 0.0\n",
    "    total_samples = 0\n",
    "    for batch in dataloader:\n",
    "        images, labels, _ = batch\n",
    "        images = images.to(device)\n",
    "        labels = labels.float().unsqueeze(1).to(device)\n",
    "        logits = model(images)\n",
    "        if class_weights is not None:\n",
    "            weight = labels * class_weights[1] + (1 - labels) * class_weights[0]\n",
    "            loss = loss_fn(logits, labels)\n",
    "            loss = (loss * weight).mean()\n",
    "        else:\n",
    "            loss = loss_fn(logits, labels)\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        if scheduler is not None:\n",
    "            scheduler.step()\n",
    "        total_loss += loss.item() * labels.size(0)\n",
    "        total_samples += labels.size(0)\n",
    "    avg_loss = total_loss / total_samples\n",
    "    return avg_loss\n",
    "\n",
    "def get_efficientnet_b3(dropout_rate=0.3):\n",
    "    model = timm.create_model('efficientnet_b3', pretrained=True)\n",
    "    n_in = model.classifier.in_features if hasattr(model, \"classifier\") else model.fc.in_features\n",
    "    model.classifier = nn.Sequential(\n",
    "        nn.Dropout(dropout_rate),\n",
    "        nn.Linear(n_in, 1)\n",
    "    )\n",
    "    return model\n",
    "\n",
    "def get_dataloader(dataset, batch_size, shuffle=False, num_workers=4, pin_memory=True):\n",
    "    return DataLoader(\n",
    "        dataset,\n",
    "        batch_size=batch_size,\n",
    "        shuffle=shuffle,\n",
    "        num_workers=num_workers,\n",
    "        pin_memory=pin_memory\n",
    "    )\n",
    "\n",
    "def get_transforms(mode='train'):\n",
    "    # Correct Cutout: Albumentations v1.4.15 provides 'Cutout' as a class, but not always in the root.\n",
    "    # Defensive import; fallback to the most robust method for v1.4.15\n",
    "    imagenet_mean = [0.485, 0.456, 0.406]\n",
    "    imagenet_std = [0.229, 0.224, 0.225]\n",
    "    if mode == 'train':\n",
    "        min_frac, max_frac = 0.05, 0.2\n",
    "        min_cut = int(300 * min_frac)\n",
    "        max_cut = int(300 * max_frac)\n",
    "        # There is no A.Cutout in v1.4.15 root, but A.augmentations.transforms.Cutout exists.\n",
    "        try:\n",
    "            from albumentations.augmentations.transforms import Cutout\n",
    "            have_cutout = True\n",
    "        except ImportError:\n",
    "            have_cutout = False\n",
    "        this_cut_h = random.randint(min_cut, max_cut)\n",
    "        this_cut_w = random.randint(min_cut, max_cut)\n",
    "        cutout_fill = [int(255 * m) for m in imagenet_mean]\n",
    "        tforms = [\n",
    "            A.RandomResizedCrop(300, 300, scale=(0.7, 1.0), ratio=(0.8, 1.2), p=1.0),\n",
    "            A.Rotate(limit=30, p=0.8),\n",
    "        ]\n",
    "        if have_cutout:\n",
    "            tforms.append(\n",
    "                Cutout(\n",
    "                    num_holes=1,\n",
    "                    max_h_size=this_cut_h,\n",
    "                    max_w_size=this_cut_w,\n",
    "                    fill_value=cutout_fill,  # RGB image in albumentations requires [R,G,B]\n",
    "                    always_apply=False,\n",
    "                    p=0.7\n",
    "                )\n",
    "            )\n",
    "        else:\n",
    "            # No available Cutout, so fallback to no cutout but emit warning\n",
    "            print(\"WARNING: albumentations.Cutout not found, continuing without Cutout augmentation\")\n",
    "        tforms.extend([\n",
    "            A.RandomContrast(limit=0.2, p=0.5),\n",
    "            A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.1),\n",
    "            A.Normalize(mean=imagenet_mean, std=imagenet_std, max_pixel_value=255.0),\n",
    "            ToTensorV2()\n",
    "        ])\n",
    "        return A.Compose(tforms)\n",
    "    else:\n",
    "        return A.Compose([\n",
    "            A.Resize(300, 300),\n",
    "            A.Normalize(mean=imagenet_mean, std=imagenet_std, max_pixel_value=255.0),\n",
    "            ToTensorV2()\n",
    "        ])\n",
    "\n",
    "print(\"Section: Model Training and Evaluation\")\n",
    "dropout_rate = round(random.uniform(0.2, 0.5), 2)\n",
    "print(f\"Model config: EfficientNet-B3, Image size 300, Head dropout={dropout_rate}\")\n",
    "\n",
    "if DEBUG:\n",
    "    print(\"DEBUG mode: using 10% subsample and 1 epoch (per fold)\")\n",
    "    sample_frac = 0.10\n",
    "    sampled_idxs = []\n",
    "    for f in range(cv_fold):\n",
    "        fold_idx = train_df.index[train_df['fold'] == f].tolist()\n",
    "        fold_labels = train_df.loc[fold_idx, 'has_cactus'].values\n",
    "        idx_pos = [i for i, l in zip(fold_idx, fold_labels) if l == 1]\n",
    "        idx_neg = [i for i, l in zip(fold_idx, fold_labels) if l == 0]\n",
    "        n_pos = max(1, int(sample_frac * len(idx_pos)))\n",
    "        n_neg = max(1, int(sample_frac * len(idx_neg)))\n",
    "        if len(idx_pos) > 0:\n",
    "            sampled_idxs += np.random.choice(idx_pos, n_pos, replace=False).tolist()\n",
    "        if len(idx_neg) > 0:\n",
    "            sampled_idxs += np.random.choice(idx_neg, n_neg, replace=False).tolist()\n",
    "    train_df = train_df.loc[sampled_idxs].reset_index(drop=True)\n",
    "    print(f\"DEBUG subsample shape: {train_df.shape}\")\n",
    "    debug_epochs = 1\n",
    "else:\n",
    "    debug_epochs = None\n",
    "\n",
    "BATCH_SIZE = 64 if torch.cuda.is_available() else 32\n",
    "N_WORKERS = 4 if torch.cuda.is_available() else 1\n",
    "EPOCHS = 20 if not DEBUG else debug_epochs\n",
    "MIN_EPOCHS = 5 if not DEBUG else 1\n",
    "EARLY_STOP_PATIENCE = 7 if not DEBUG else 2\n",
    "LR = 1e-3\n",
    "\n",
    "model_files = [os.path.join(MODEL_DIR, f\"efficientnet_b3_fold{f}.pt\") for f in range(cv_fold)]\n",
    "if all([os.path.exists(f) for f in model_files]):\n",
    "    print(\"All fold models found in models/. Running inference and file saving only (no retrain).\")\n",
    "    inference_and_submission(train_df, train_id2path, test_img_ids, test_id2path, dropout_rate,\n",
    "                            class_weights, need_weights, BATCH_SIZE, N_WORKERS, cv_fold)\n",
    "    return\n",
    "\n",
    "oof_true, oof_pred, fold_scores, fold_val_ids = [], [], [], []\n",
    "start_time = time.time() if DEBUG else None\n",
    "\n",
    "for fold in range(cv_fold):\n",
    "    print(f\"\\n=== FOLD {fold} TRAINING ===\")\n",
    "    df_train = train_df[train_df['fold'] != fold].reset_index(drop=True)\n",
    "    df_val = train_df[train_df['fold'] == fold].reset_index(drop=True)\n",
    "    print(f\"Train size: {df_train.shape[0]}, Val size: {df_val.shape[0]}\")\n",
    "    train_img_ids = df_train['id'].tolist()\n",
    "    train_labels = df_train['has_cactus'].values\n",
    "    val_img_ids = df_val['id'].tolist()\n",
    "    val_labels = df_val['has_cactus'].values\n",
    "\n",
    "    train_ds = CactusDataset(\n",
    "        train_img_ids, train_labels,\n",
    "        id2path=train_id2path,\n",
    "        transforms=get_transforms(\"train\")\n",
    "    )\n",
    "    val_ds = CactusDataset(\n",
    "        val_img_ids, val_labels,\n",
    "        id2path=train_id2path,\n",
    "        transforms=get_transforms(\"val\")\n",
    "    )\n",
    "    train_loader = get_dataloader(train_ds, BATCH_SIZE, shuffle=True, num_workers=N_WORKERS)\n",
    "    val_loader = get_dataloader(val_ds, BATCH_SIZE, shuffle=False, num_workers=N_WORKERS)\n",
    "    model = get_efficientnet_b3(dropout_rate=dropout_rate)\n",
    "    model.to(DEVICE)\n",
    "    loss_fn = nn.BCEWithLogitsLoss(reduction='none')\n",
    "    optimizer = optim.AdamW(model.parameters(), lr=LR)\n",
    "    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)\n",
    "    fold_class_weights = class_weights if need_weights else None\n",
    "    if fold_class_weights is not None:\n",
    "        fold_class_weights = torch.tensor(fold_class_weights).float().to(DEVICE)\n",
    "    best_auc = -np.inf\n",
    "    best_epoch = -1\n",
    "    best_model_state = None\n",
    "    patience = 0\n",
    "\n",
    "    for epoch in range(EPOCHS):\n",
    "        train_loss = train_one_epoch(\n",
    "            model, loss_fn, optimizer, scheduler, train_loader, DEVICE, fold_class_weights)\n",
    "        val_loss, val_true, val_pred = eval_model(\n",
    "            model, loss_fn, val_loader, DEVICE, fold_class_weights)\n",
    "        val_auc = roc_auc_score(val_true, val_pred)\n",
    "        cm = confusion_info(val_true, val_pred)\n",
    "        print(f\"Epoch {epoch+1:02d}: train_loss={train_loss:.4f} val_loss={val_loss:.4f} val_auc={val_auc:.4f}\")\n",
    "        print(f\" Val confusion_matrix (rows:true [0,1]; cols:pred [0,1]):\\n{cm}\")\n",
    "        if val_auc > best_auc:\n",
    "            best_auc = val_auc\n",
    "            best_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}\n",
    "            best_epoch = epoch\n",
    "            patience = 0\n",
    "        else:\n",
    "            patience += 1\n",
    "        if DEBUG and epoch + 1 >= debug_epochs:\n",
    "            break\n",
    "        if (epoch + 1) >= MIN_EPOCHS and patience >= EARLY_STOP_PATIENCE:\n",
    "            print(f\"Early stopping at epoch {epoch+1}, best_epoch={best_epoch+1}.\")\n",
    "            break\n",
    "\n",
    "    model.load_state_dict(best_model_state)\n",
    "    fold_model_path = os.path.join(MODEL_DIR, f\"efficientnet_b3_fold{fold}.pt\")\n",
    "    torch.save(model.state_dict(), fold_model_path)\n",
    "    print(f\"Saved best model for fold {fold} at {fold_model_path} (best_auc={best_auc:.5f}, best_epoch={best_epoch+1})\")\n",
    "\n",
    "    _, val_true, val_pred = eval_model(model, loss_fn, val_loader, DEVICE, fold_class_weights)\n",
    "    oof_true.append(val_true)\n",
    "    oof_pred.append(val_pred)\n",
    "    fold_val_ids.append(val_img_ids)\n",
    "    fold_scores.append(best_auc)\n",
    "    print(f\"OOF stored for fold {fold}, Validation AUC={best_auc:.5f}\")\n",
    "\n",
    "end_time = time.time() if DEBUG else None\n",
    "if DEBUG:\n",
    "    debug_time = end_time - start_time\n",
    "    estimated_time = (1 / 0.1) * (EPOCHS / debug_epochs) * debug_time\n",
    "    print(\"=== Start of Debug Information ===\")\n",
    "    print(f\"debug_time: {debug_time:.1f}\")\n",
    "    print(f\"estimated_time: {estimated_time:.1f}\")\n",
    "    print(\"=== End of Debug Information ===\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3f0269e",
   "metadata": {},
   "source": [
    "## Ensemble Strategy and Final Predictions\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "308dcdb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Section: Ensemble Strategy and Final Predictions\")\n",
    "all_oof_true = np.concatenate(oof_true)\n",
    "all_oof_pred = np.concatenate(oof_pred)\n",
    "oof_auc = roc_auc_score(all_oof_true, all_oof_pred)\n",
    "oof_cm = confusion_info(all_oof_true, all_oof_pred)\n",
    "print(f\"OOF ROC-AUC: {oof_auc:.5f}\")\n",
    "print(f\"OOF Confusion Matrix:\\n{oof_cm}\")\n",
    "\n",
    "test_ds = CactusDataset(\n",
    "    test_img_ids, labels=None,\n",
    "    id2path=test_id2path,\n",
    "    transforms=get_transforms(\"val\")\n",
    ")\n",
    "test_loader = get_dataloader(test_ds, BATCH_SIZE, shuffle=False, num_workers=N_WORKERS)\n",
    "test_pred_list = []\n",
    "for fold in range(cv_fold):\n",
    "    fold_model_path = os.path.join(MODEL_DIR, f\"efficientnet_b3_fold{fold}.pt\")\n",
    "    model = get_efficientnet_b3(dropout_rate=dropout_rate)\n",
    "    model.load_state_dict(torch.load(fold_model_path, map_location='cpu'))\n",
    "    model.to(DEVICE)\n",
    "    model.eval()\n",
    "    preds = []\n",
    "    with torch.no_grad():\n",
    "        for batch in test_loader:\n",
    "            images, img_ids = batch\n",
    "            images = images.to(DEVICE)\n",
    "            logits = model(images)\n",
    "            probs = torch.sigmoid(logits).cpu().numpy().reshape(-1)\n",
    "            preds.append(probs)\n",
    "    fold_test_pred = np.concatenate(preds)\n",
    "    test_pred_list.append(fold_test_pred)\n",
    "    print(f\"Loaded fold {fold} for test prediction.\")\n",
    "test_probs = np.mean(test_pred_list, axis=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58b5ded8",
   "metadata": {},
   "source": [
    "## Submission File Generation\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "988914c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Section: Submission File Generation\")\n",
    "submission = pd.read_csv(SAMPLE_SUB_PATH)\n",
    "submission['has_cactus'] = test_probs\n",
    "submission.to_csv('submission.csv', index=False)\n",
    "print(f\"Saved submission.csv in required format with {len(submission)} rows.\")\n",
    "\n",
    "scores_df = pd.DataFrame({\n",
    "    'Model': [f\"efficientnet_b3_fold{f}\" for f in range(cv_fold)] + ['ensemble'],\n",
    "    'ROC-AUC': list(fold_scores) + [oof_auc]\n",
    "})\n",
    "scores_df.set_index('Model', inplace=True)\n",
    "scores_df.to_csv(\"scores.csv\")\n",
    "print(f\"Saved cross-validation scores to scores.csv\")"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}
